% solve_eq: solves the model's equilibrium (prices, rate changes, and fringe entry) market by market
function [p1_all, p2_all, n_fringes_mkt, ...
    profit_all, non_converge_all, solve_market_all, fsolve_all, ...
    E_value_y, E_value, enrollment, enrollment_m, enrollment_f, enrollment_per_m, enrollment_per_f, non_converge, solve_market,...
    s1_mkt_all, s2_mkt_all, rs_share_all]...
    = solve_eq(alpha, ...
    Delta_xi_all, cost1_e_all, cost2_1_all,  mu_f_all, sigma_f_all,...
    cost2_0, D_input, N, ...
    K, pr_y, N_y, resource_y1, resource_y2, y_medi, x_fc, pr_fc, ...
    rho, crra, sigma_c1, beta_c, beta_f, T1, B1_c, B2_c, B1_f, B2_f, delta, ...
    lr_target, N_f, entry_exitflag_vec, id_ist_old, myopic, Er,...
    p1_ini, p2_ini, n_ini, cost_weight, use_cost_weight)
% SOLVE_EQ  Solve the equilibrium of every market (calendar year x state).
%
% For each market mm = 1..N.st a damped fixed point is iterated over
%   - first-period prices p1 of the major firms and of ONE representative
%     (symmetric) fringe firm, each solved from foc_1st with fsolve;
%   - second-period prices p2 (one column per aggregate state, KM columns),
%     set by a closed-form rule that trades the extra revenue of a rate
%     increase against the rate-stability regulation cost
%     cost2_0 + 0.5*cost2_1*(p2-p1)^2;
%   - the (continuous) number of fringes N_f*logncdf(profit_f, mu_f, sigma_f),
%     rounded to an integer >= 1 for use in demand.
% Markets with mean([cost1_e_m; cost1_e_f]) < 0 or entry_exitflag_vec(mm) <= 0
% are skipped and flagged with solve_market = 0.
%
% Insurer-level outputs (N.ist rows; each market's fringes share one row that
% holds per-fringe values):
%   p1_all, p2_all          prices at which demand was last evaluated
%   profit_all              B1_f*p1*s1 + beta_f^T1*B2_f*E[period-2 profit]
%                           - min-loss-ratio cost
%   s1_mkt_all, s2_mkt_all  market shares integrated over y with pr_y
%   rs_share_all            expected rate-stability cost / profit
%   non_converge_all        1 if the fixed point did not reach tol
%   solve_market_all        1 if the market was solved
%   fsolve_all              1 if fsolve's exitflag > 0 in the last iteration
% Market-level outputs (N.st rows):
%   n_fringes_mkt           integer number of fringes used in the last iteration
%   E_value_y, E_value      log(sum(exp([v_c v0]))) per y, and its pr_y average
%   enrollment, enrollment_m, enrollment_f, enrollment_per_m, enrollment_per_f
%                           total / majors / all fringes / mean per major /
%                           per-fringe inside shares
%   non_converge, solve_market  market-level versions of the flags above
%
% NOTE(review): input shapes are inferred from how they are indexed below
% (e.g. *_all arrays and D_input fields share the same row order, delta is a
% single unique value); confirm against the callers.

%% notations
% N.ist: # majors + # markets 
% N_ist: # majors + # fringes (unmerged version)

%% consumer's value for the outside option (one entry per geographical state)
state_pr_s = unique([D_input.state D_input.pr_k], 'rows');
lapse_u = 0; % lapses are exogenous 
for ss = 1:N.s
    
    pr_s = state_pr_s(ss,2:K+1)';  % pr_s: K x 1
    
    % consumer_v0_ftn called from demand estimation folder
    [v0, U0] = consumer_v0_ftn(N_y, rho, crra, beta_c, K, ...
        pr_s, x_fc, pr_fc, y_medi, B1_c, B2_c, T1, alpha,...
        lapse_u, resource_y1, resource_y2);
    
    % NOTE(review): S is later indexed by the state code itself, which
    % assumes state codes are 1..N.s in sorted order.
    S(ss).v0 = v0; % N_y x 1
    S(ss).U0 = U0; % N_y x K
end

%% p1_target: determined by observed claims & target loss ratio
D_input_pr_s = D_input.pr_s;
D_input_claims2 = D_input.claims2;

% p1_target_ftn called from supply estimation folder
p1_target_all = p1_target_ftn(lr_target, D_input_pr_s, D_input_claims2, ...
    beta_f, T1, B2_f,  B1_f, Er);

%% solve the model for each market
KM = size(D_input.p2,2);

% array to store eqm choices for each firm: for fringes, just store one
% value in each market (symmetry)
p1_all = zeros(N.ist, 1);
p2_all = zeros(N.ist, KM);
profit_all = zeros(N.ist, 1);
s1_mkt_all = zeros(N.ist, 1);
s2_mkt_all = zeros(N.ist, KM);
non_converge_all = zeros(N.ist,1);
solve_market_all = zeros(N.ist,1);
fsolve_all       = zeros(N.ist,1); 
rs_share_all = zeros(N.ist,1);

% array to store other eqm outcomes at the market level
E_value_y     = zeros(N.st, N_y); % consumer's expected value conditional on y
E_value       = zeros(N.st, 1); % consumer's expected value unconditional on y
enrollment    = zeros(N.st, 1); % total insured rate in the market
enrollment_m    = zeros(N.st, 1); % total insured rate from major firms in the market
enrollment_f   = zeros(N.st, 1); % total insured rate from fringes in the market
enrollment_per_m    = zeros(N.st, 1); % mean insured rate per major firm in the market
enrollment_per_f   = zeros(N.st, 1); % mean insured rate per fringe in the market
non_converge  = zeros(N.st, 1); % 1 if no fixed pt found
solve_market  = zeros(N.st, 1); % 1 if the market's cost estimates make sense and fixed pt is tried
n_fringes_mkt = zeros(N.st, 1);

% bug checking
if N.st ~= max(D_input.mkt)
    disp('ERROR in indexing markets, D_input.mkt')
end

% fixed point settings
tol      = 1e-3; 
max_iter = 5000;
weight   = 0.2; % weight on the OLD iterate in the damped update

tolf = 1e-11; % fsolve tolerance
options = optimoptions('fsolve','Display','none', ...
    'FunctionTolerance', tolf, 'OptimalityTolerance', tolf,...
    'MaxIterations', 1000, ...
    'MaxFunctionEvaluations', 2000,...
    'StepTolerance', tolf);

% prices are capped one unit below the smallest resource level
max_p = min([resource_y1; resource_y2])-1;

% period-2 retention (lapses are exogenous): invariant across markets and
% iterations, so computed once
pr_stay = 1-unique(delta);
pr_stay_fixed2 = pr_stay*ones(N_y,KM); % N_y x KM

% loop through each market (calendar year x geographical state combination)
for mm = 1:N.st
    %% indices for firms in market mm
    % index for majors
    iind_m = find(D_input.mkt==mm & D_input.major==1); % iind_m, = 1,..., N_ist
    NN_m = length(iind_m);

    % index for fringes 
    iind_f = find(D_input.mkt==mm & D_input.major==0); % iind_f = 1,..., N_ist
    
    % needed to store in N.ist arrays with the original sorting as D
    id_ist_m = D_input.id_ist(iind_m); % NN_m x 1
    id_ist_f = unique(D_input.id_ist(iind_f)); % 1 x 1
    if length(id_ist_f)~=1
        disp('ERROR in length(id_ist_f)')
    end

    %% exogenous values 
    state_mm  = unique(D_input.state(iind_m)); % geographical state
    
    pr_s_row = D_input.pr_s(iind_m(1),:); % 1 x KM: the prob dist of agg states

    D_claims2_m = D_input.claims2(iind_m,:); % NN_m x KM
    D_claims2_f = D_input.claims2(iind_f(1),:); % 1 x KM

    p1_target_m = p1_target_all(iind_m,:);
    p1_target_f = p1_target_all(iind_f(1),:);

    Delta_xi_m = Delta_xi_all(iind_m,:);
    Delta_xi_f = Delta_xi_all(iind_f(1),:);
    
    cost1_e_m = cost1_e_all(iind_m,:);
    cost1_e_f = cost1_e_all(iind_f(1),:);

    cost2_1_m = cost2_1_all(iind_m,:);
    cost2_1_f = cost2_1_all(iind_f(1),:);

    cost2_0_m = cost2_0(iind_m,:); % NN_m x KM
    cost2_0_f = cost2_0(iind_f(1),:); % 1 x KM

    % the market's entry cost distribution = logN(mu_f_mm, sigma_f_mm)
    mu_f_mm    = unique(mu_f_all(iind_f));    % mu_f_all: N_ist x 1
    sigma_f_mm = unique(sigma_f_all(iind_f)); % sigma_f_all: N_ist x 1

    % outside option 
    v0_mm = S(state_mm).v0;  % N_y x 1
    U0_mm = S(state_mm).U0;  % N_y x K

    %% initial guess on optimal choices
    p1_m = p1_ini(iind_m,:); % NN_m x 1
    p1_f = p1_ini(iind_f(1),:); % just one row for the symmetric fringe
    p2_m = p2_ini(iind_m,:); % NN_m x KM
    p2_f = p2_ini(iind_f(1),:); % just one row for the symmetric fringe

    n_variety = n_ini(mm); % 1 x 1, # fringes in the market (integer used in demand)

    n_variety_c = n_variety; % continuous counterpart that is damped

    % array to store updated optimal choices
    p1_m_new = zeros(size(p1_m)); % NN_m x 1
    p1_f_new = zeros(size(p1_f)); % 1 x 1
    p2_m_new = zeros(size(p2_m)); % NN_m x KM
    p2_f_new = zeros(size(p2_f)); % 1 x KM

    fsolve_m_new = zeros(NN_m,1);
    fsolve_f_new = zeros(1,1);

    p1_m(p1_m>max_p)=max_p;
    p1_f(p1_f>max_p)=max_p;
    p2_m(p2_m>max_p)=max_p;
    p2_f(p2_f>max_p)=max_p;

    %% fixed point for optimal choices
    if mean([cost1_e_m; cost1_e_f])<0 || entry_exitflag_vec(mm)<=0 % markets not to be solved
        % insurer-level values: save for majors
        for jj=1:NN_m
            ll=find(id_ist_m(jj)==id_ist_old);
            solve_market_all(ll,:) = 0;
        end
        % insurer-level values: save one row for fringe
        ll = find(id_ist_f==id_ist_old);
        solve_market_all(ll,:) = 0;

        % market-level values
        solve_market(mm) = 0;

    else % cost estimates fine so solve
        diff = 1;
        iter = 0;

        while (diff>tol && iter<max_iter)
            
            iter=iter+1;

            %% demand using (p1, p2, n_variety) from the previous iteration

            % # major firms + # fringes -> updated each iteration
            NN = NN_m + n_variety;

            % stack majors followed by repeated fringes 
            p1 = [p1_m; repmat(p1_f,n_variety,1)]; % NN x 1
            p2 = [p2_m; repmat(p2_f,n_variety,1)]; % NN x KM
            Delta_xi = [Delta_xi_m; repmat(Delta_xi_f, n_variety,1)];

            % correct beliefs: consumers anticipate the actual p2
            % - v_c: N_y x NN
            p2_input = p2; % NN x KM
            [v_c, ~] = predict_v(p1, p2_input, ...
                alpha,  Delta_xi,...
                N_y, resource_y1, resource_y2, pr_s_row, KM,...
                rho, crra, beta_c, U0_mm, B1_c, B2_c, T1, delta);

            % misinformed beliefs: consumers expect p2 = p1 in every state
            % - v_m: N_y x NN
            p2_input = repmat(p1, 1, KM); % NN x KM
            [v_m, ~] = predict_v(p1, p2_input, ...
                alpha,  Delta_xi,...
                N_y, resource_y1, resource_y2, pr_s_row, KM,...
                rho, crra, beta_c, U0_mm, B1_c, B2_c, T1, delta);

            % choice probabilities including outside option -> per-fringe
            % shares are computed 
            ind_temp = ones(1,NN);
            s1_ind_c = share_ind_ftn(v_c, ind_temp, v0_mm, sigma_c1); %N_y x NN+1
            s1_ind_m = share_ind_ftn(v_m, ind_temp, v0_mm, sigma_c1); %N_y x NN+1

            % mixture of correct and misinformed consumers: N_y x NN+1
            s1_ind = (1-myopic)*s1_ind_c + myopic*s1_ind_m;

            s1_ind1 = s1_ind(:,1:end-1); % N_y x NN (inside options only)
            s1_mkt = sum(s1_ind1.*repmat(pr_y, 1, NN))'; % NN x 1, integrated over y
            s1_mkt_m = s1_mkt(1:NN_m,:);
            s1_mkt_f = s1_mkt(NN_m+1,:);

            % 2nd-period market share (lapses are exogenous)
            s2_mkt = repmat(s1_mkt, 1, KM)*pr_stay;
            s2_mkt_m = s2_mkt(1:NN_m,:); 
            s2_mkt_f = s2_mkt(NN_m+1,:);

            %% add risk-weighted average claims -> will be used in profit calculation
            D_claims_av_m=zeros(NN_m, KM);
            D_claims_av_f=zeros(1,KM);
            for j=1:NN_m
                % firm j's period-2 individual shares; does not depend on k
                s2_ind1_jj=repmat(s1_ind1(:,j), 1, KM).* pr_stay_fixed2; % (N_y, KM) matrix
                for k=1:KM
                    D_claims2_m_jj=D_claims2_m(j,k);
                    D_claims_av_m(j,k) = sum(cost_weight.*repmat(D_claims2_m_jj, 1, N_y)'.*s2_ind1_jj(:,k).*pr_y, 'all'); %sum (N_y) at firm level
                end
            end
            s1_ind1_jj=s1_ind1(:,NN_m+1); %N_y x 1
            s2_ind1_jj=repmat(s1_ind1_jj, 1, KM).* pr_stay_fixed2; % (N_y, KM) matrix
            for k=1: KM
                D_claims2_f_jj=D_claims2_f(k);
                D_claims_av_f(k) = sum(cost_weight.*repmat(D_claims2_f_jj, 1, N_y)'.*s2_ind1_jj(:,k).*pr_y, 'all'); %sum (N_y) at firm level
            end

            %% update p2 to p2_new
            % Raising p2 above p1 by s2/cost2_1 pays only when the gain
            % s2^2/(2*cost2_1) exceeds the fixed cost cost2_0; otherwise p2 = p1.
            % majors
            for j=1:NN_m
                for k=1:KM
                    if cost2_1_m(j,k)>= (s2_mkt_m(j,k)^2)/(2*cost2_0_m(j,k))
                        p2_m_new(j,k)=p1_m(j);
                    else
                        p2_m_new(j,k)=p1_m(j)+ s2_mkt_m(j,k)/cost2_1_m(j,k);
                    end
                end
            end

            % for fringe, just one row
            j=1;
            for k=1:KM
                if cost2_1_f(j,k)>= (s2_mkt_f(j,k)^2)/(2*cost2_0_f(j,k))
                    p2_f_new(j,k)=p1_f(j);
                else
                    p2_f_new(j,k)=p1_f(j)+ s2_mkt_f(j,k)/cost2_1_f(j,k);
                end
            end

            p2_m_new(p2_m_new>max_p)=max_p;
            p2_f_new(p2_f_new>max_p)=max_p;

            %% update p1 to p1_new
            % loop thru majors
            for j=1:NN_m
                
                Delta_xi_jj  = Delta_xi_m(j); % scalar
                p1_target_jj = p1_target_m(j); % scalar
                p2_jj        = p2_m(j,:); % 1 x KM

                cost1_e_jj   = cost1_e_m(j,:); % scalar
                cost2_1_jj   = cost2_1_m(j,:); % 1 x KM
                D_claims2_jj = D_claims2_m(j,:); % 1 x KM

                foc_eval=@(x0)foc_1st(x0, p2_jj, ...
                    cost1_e_jj, cost2_1_jj, j, v_c, v_m, v0_mm, p1_target_jj,...
                    D_claims2_jj, pr_s_row, ...
                    alpha, Delta_xi_jj, ...
                    KM, resource_y1, resource_y2, N_y, rho, crra, U0_mm, pr_y, sigma_c1, ...
                    myopic, beta_c, beta_f, T1, B1_c, B2_c, B1_f, B2_f, delta, cost_weight, use_cost_weight);

                x0 = p1_m(j);
                [x1_new, ~, exitflag] = fsolve(foc_eval, x0, options);
                p1_m_new(j,:) = x1_new(1);
                fsolve_m_new(j) = (exitflag>0); % 1 if equation solved
            end

            % fringe: the representative fringe sits at column NN_m+1 of v_c/v_m
            j=1; % just store one row
            Delta_xi_jj  = Delta_xi_f(j); % scalar
            p1_target_jj = p1_target_f(j); % scalar
            p2_jj        = p2_f(j,:); % 1 x KM

            cost1_e_jj   = cost1_e_f(j,:); % scalar
            cost2_1_jj   = cost2_1_f(j,:); % 1 x KM
            D_claims2_jj = D_claims2_f(j,:); % 1 x KM

            % NOTE(review): the firm index passed here is 1 (the first major),
            % not NN_m+1 (the fringe's column). Kept as in the original; confirm
            % how foc_1st uses this argument.
            foc_eval=@(x0)foc_1st(x0, p2_jj, ...
                cost1_e_jj, cost2_1_jj, j, v_c, v_m, v0_mm, p1_target_jj,...
                D_claims2_jj, pr_s_row, ...
                alpha, Delta_xi_jj, ...
                KM, resource_y1, resource_y2, N_y, rho, crra, U0_mm, pr_y, sigma_c1, ...
                myopic, beta_c, beta_f, T1, B1_c, B2_c, B1_f, B2_f, delta, cost_weight, use_cost_weight);

            x0 = p1_f(j);
            [x1_new, ~, exitflag] = fsolve(foc_eval, x0, options);
            p1_f_new(j,:) = x1_new(1);
            fsolve_f_new(j) = (exitflag>0); % 1 if equation solved

            p1_m_new(p1_m_new>max_p)=max_p;
            p1_f_new(p1_f_new>max_p)=max_p;

            %% update n_variety to n_variety_new for fringe firms only
            % rate stability regulation cost at (p1_f, p2_f): only paid in
            % states where the price is raised
            delta_p = p2_f-repmat(p1_f,1,KM); % 1 x KM
            cost_rs_temp = cost2_0_f + 0.5*cost2_1_f.*(delta_p.^2);
            cost_rs_temp = cost_rs_temp(delta_p>0);
            cost_rs = zeros(size(delta_p)); % 1 x KM
            cost_rs(delta_p>0) = cost_rs_temp;
            
            % min loss ratio regulation cost
            p1_deviation = p1_f-p1_target_f;
            cost_ml = 0.5*cost1_e_f.*(p1_deviation.^2); % 1 x 1
            
            % fringe's profit 
            profit1 = B1_f*p1_f.*s1_mkt_f;

            if use_cost_weight==0
                profit2_k = (p2_f-D_claims2_f).*s2_mkt_f - cost_rs;
            else
                profit2_k = p2_f.*s2_mkt_f-D_claims_av_f - cost_rs;
            end

            profit2 = (beta_f^T1)*B2_f*...
                sum(repmat(pr_s_row, size(profit2_k,1),1).*profit2_k,2);
            Ecost_rs_f = (beta_f^T1)*B2_f*...
                sum(repmat(pr_s_row, size(profit2_k,1),1).*cost_rs,2);
            profit_f = profit1 + profit2 - cost_ml; % 1 x 1
         
            % free entry: a fringe enters if its profit exceeds a logN entry cost
            n_variety_c_new = N_f*logncdf(profit_f, mu_f_mm, sigma_f_mm); % not an integer

            %% evaluate the difference
            diff = mean(([p1_m_new; p1_f_new] - [p1_m; p1_f]).^2) ...
                + mean(mean(([p2_m_new; p2_f_new] - [p2_m; p2_f]).^2))...
                + mean((n_variety_c_new - n_variety_c).^2); 

            %% update for the next iteration
            % skipped on the final pass so that the stored prices/shares and
            % v_c all refer to the same (last evaluated) iterate
            if diff>tol && iter<max_iter

                p1_m = weight*p1_m+(1-weight)*p1_m_new; % NN_m x 1
                p1_f = weight*p1_f+(1-weight)*p1_f_new; % 1 x 1

                p2_m = weight*p2_m+(1-weight)*p2_m_new; % NN_m x KM
                p2_f = weight*p2_f+(1-weight)*p2_f_new; % 1 x KM

                n_variety_c = weight*n_variety_c+(1-weight)*n_variety_c_new; % scalar, not an integer

                % make it an integer>=1
                n_variety = round(n_variety_c);
                if n_variety<1 
                    n_variety = 1;
                end
            end

        end

        % Flag non-convergence from the actual stopping criterion: a market
        % that reached tol on iteration max_iter is converged, and a NaN diff
        % (which also ends the while loop) is not.
        if ~(diff <= tol)
            non_converge(mm)=1;
            for jj=1:NN_m
                ll=find(id_ist_m(jj)==id_ist_old);
                non_converge_all(ll)=1;
            end
            ll = find(id_ist_f==id_ist_old);
            non_converge_all(ll)=1;
        end

        % compute major profits which weren't calculated during the iterations
        % - rate stability regulation cost at (p1_m, p2_m)
        delta_p = p2_m-repmat(p1_m,1,KM); % NN_m x KM
        cost_rs_temp = cost2_0_m + 0.5*cost2_1_m.*(delta_p.^2);
        cost_rs_temp = cost_rs_temp(delta_p>0);
        cost_rs = zeros(size(delta_p)); % NN_m x KM
        cost_rs(delta_p>0) = cost_rs_temp;

        % - min loss ratio regulation cost
        p1_deviation = p1_m-p1_target_m;
        cost_ml = 0.5*cost1_e_m.*(p1_deviation.^2); % NN_m x 1

        % - profit 
        profit1 = B1_f*p1_m.*s1_mkt_m;

        if use_cost_weight==0
            profit2_k = (p2_m-D_claims2_m).*s2_mkt_m - cost_rs;
        else
            profit2_k = p2_m.*s2_mkt_m - D_claims_av_m - cost_rs;
        end

        profit2 = (beta_f^T1)*B2_f*...
            sum(repmat(pr_s_row, size(profit2_k,1),1).*profit2_k,2);
        Ecost_rs_m = (beta_f^T1)*B2_f*...
            sum(repmat(pr_s_row, size(profit2_k,1),1).*cost_rs,2);
        profit_m = profit1 + profit2 - cost_ml; % NN_m x 1

        % store converged values
        % insurer-level values: majors
        for jj=1:NN_m
            ll=find(id_ist_m(jj)==id_ist_old);
            p1_all(ll,:) = p1_m(jj,:);
            p2_all(ll,:) = p2_m(jj,:);
            solve_market_all(ll,:) = 1;
            fsolve_all(ll,:) = fsolve_m_new(jj,:);
            s1_mkt_all(ll,:) = s1_mkt_m(jj,:);
            s2_mkt_all(ll,:) = s2_mkt_m(jj,:);
            profit_all(ll,:) = profit_m(jj,:);
            rs_share_all(ll,:) = Ecost_rs_m(jj,:)./profit_m(jj,:);
        end

        % insurer-level values: save one row for fringe. these are all
        % scaled down to per-fringe values. 
        ll = find(id_ist_f==id_ist_old);
        p1_all(ll,:) = p1_f;
        p2_all(ll,:) = p2_f;
        solve_market_all(ll,:) = 1;
        fsolve_all(ll,:) = fsolve_f_new;
        s1_mkt_all(ll,:) = s1_mkt_f;
        s2_mkt_all(ll,:) = s2_mkt_f;
        profit_all(ll,:) = profit_f;
        rs_share_all(ll,:) = Ecost_rs_f./profit_f;

        % market-level values
        solve_market(mm) = 1;
        n_fringes_mkt(mm) = n_variety;
        enrollment(mm) = sum(s1_mkt_m) + n_variety*s1_mkt_f;
        enrollment_m(mm) = sum(s1_mkt_m);
        enrollment_f(mm) = n_variety*s1_mkt_f;
        enrollment_per_m(mm) = mean(s1_mkt_m);
        enrollment_per_f(mm) = s1_mkt_f;

        % we calculate consumer welfare using the correct p2 
        v_mm = [v_c v0_mm]; % N_y x (NN+1)
        exp_v_mm = exp(v_mm);

        % expected value in the market mm
        E_value_y(mm,:) = log(sum(exp_v_mm,2))';
        E_value(mm)     = sum(pr_y.*log(sum(exp_v_mm,2)));

    end

end   
















